import torch
import torch.nn as nn

from onmt.Utils import aeq, sequence_mask


class GlobalAttention(nn.Module):
	"""
	Global attention takes a matrix and a query vector. It
	then computes a parameterized convex combination of the matrix
	based on the input query.

	Constructs a unit mapping a query `q` of size `dim`
	and a source matrix `H` of size `n x dim`
	to an output of size `dim`.


	.. mermaid::

	   graph BT
		  A[Query]
		  subgraph RNN
			C[H 1]
			D[H 2]
			E[H N]
		  end
		  F[Attn]
		  G[Output]
		  A --> F
		  C --> F
		  D --> F
		  E --> F
		  C -.-> G
		  D -.-> G
		  E -.-> G
		  F --> G

	All models compute the output as
	:math:`c = \sum_{j=1}^{SeqLength} a_j H_j` where
	:math:`a_j` is the softmax of a score function.
	A projection layer is then applied to :math:`[q, c]`.

	However, the models differ in how they
	compute the attention score.

	* Luong Attention (dot, general):
	   * dot: :math:`score(H_j,q) = H_j^T q`
	   * general: :math:`score(H_j, q) = H_j^T W_a q`


	* Bahdanau Attention (mlp):
	   * :math:`score(H_j, q) = v_a^T \tanh(W_a q + U_a H_j)`
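
	As a purely illustrative instance of the dot variant: with
	:math:`H_j = [1, 0]` and :math:`q = [2, 3]`,
	:math:`score(H_j, q) = 1 \cdot 2 + 0 \cdot 3 = 2`; the general and mlp
	variants insert the learned parameters :math:`W_a`, :math:`U_a`,
	:math:`v_a` into this same comparison.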


	Args:
	   dim (int): dimensionality of query and key
	   coverage (bool): use coverage term
	   attn_type (str): type of attention to use, options [dot,general,mlp]
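
	Example (a minimal usage sketch; tensor contents and shapes are
	illustrative only)::

	   attn = GlobalAttention(64, attn_type="general")
	   query = torch.randn(4, 3, 64)   # [batch x tgt_len x dim]
	   memory = torch.randn(4, 7, 64)  # [batch x src_len x dim]
	   attn_h, align, c = attn(query, memory)
	   # attn_h: [tgt_len x batch x dim], align: [tgt_len x batch x src_len]
	   # pass memory_lengths=LongTensor([...]) to mask padded source positions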

	"""
	def __init__(self, dim, coverage=False, attn_type="dot"):
		super(GlobalAttention, self).__init__()

		self.dim = dim
		self.attn_type = attn_type
		assert (self.attn_type in ["dot", "general", "mlp"]), (
				"Please select a valid attention type.")

		if self.attn_type == "general":
			self.linear_in = nn.Linear(dim, dim, bias=False)
		elif self.attn_type == "mlp":
			self.linear_context = nn.Linear(dim, dim, bias=False)
			self.linear_query = nn.Linear(dim, dim, bias=True)
			self.v = nn.Linear(dim, 1, bias=False)
		# mlp wants it with bias
		out_bias = self.attn_type == "mlp"
		self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)

		self.sm = nn.Softmax(dim=-1)  # normalize over the source dimension
		self.tanh = nn.Tanh()

		if coverage:
			self.linear_cover = nn.Linear(1, dim, bias=False)

	def score(self, h_t, h_s):
		"""
		Args:
		  h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
		  h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

		Returns:
		  :obj:`FloatTensor`:
		   raw attention scores (unnormalized) for each src index
		  `[batch x tgt_len x src_len]`
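
		Example (an illustrative shape check; the module and tensors are
		stand-ins, with random values)::

		   attn = GlobalAttention(64, attn_type="dot")
		   scores = attn.score(torch.randn(4, 3, 64), torch.randn(4, 7, 64))
		   # scores.size() == (4, 3, 7): one score per (query, source) pair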

		"""

		# Check input sizes
		src_batch, src_len, src_dim = h_s.size()
		tgt_batch, tgt_len, tgt_dim = h_t.size()
		aeq(src_batch, tgt_batch)
		aeq(src_dim, tgt_dim)
		aeq(self.dim, src_dim)

		if self.attn_type in ["general", "dot"]:
			if self.attn_type == "general":
				h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim)
				h_t_ = self.linear_in(h_t_)
				h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
			h_s_ = h_s.transpose(1, 2)
			# (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
			return torch.bmm(h_t, h_s_)
		else:
			dim = self.dim
			# W_a q: project the queries, broadcast over source positions
			wq = self.linear_query(h_t.view(-1, dim))
			wq = wq.view(tgt_batch, tgt_len, 1, dim)
			wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

			# U_a H_j: project the source states, broadcast over target positions
			uh = self.linear_context(h_s.contiguous().view(-1, dim))
			uh = uh.view(src_batch, 1, src_len, dim)
			uh = uh.expand(src_batch, tgt_len, src_len, dim)

			# (batch, t_len, s_len, d)
			wquh = self.tanh(wq + uh)

			# v_a^T tanh(...): collapse the feature dim to one score per pair
			return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

	def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
		"""

		Args:
		  input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
		  memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
		  memory_lengths (`LongTensor`): the source context lengths `[batch]`
		  coverage (`FloatTensor`): coverage vector `[batch x src_len]`, or None

		Returns:
		  (`FloatTensor`, `FloatTensor`, `FloatTensor`):

		  * Computed attentional hidden state `[tgt_len x batch x dim]`
		  * Attention distributions for each query
			 `[tgt_len x batch x src_len]`
		  * Weighted context vector `[tgt_len x batch x dim]`
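
		Example (a one-step decoding sketch; sizes and tensor names are
		illustrative, not from a real model)::

		   attn = GlobalAttention(64)
		   query = torch.randn(2, 64)      # one decoder step: [batch x dim]
		   memory = torch.randn(2, 5, 64)  # encoder states: [batch x src_len x dim]
		   attn_h, align, c = attn(query, memory)
		   # one-step outputs are squeezed: attn_h is [batch x dim],
		   # align is [batch x src_len] and sums to 1 over src_len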
		"""

		# one step input
		if input.dim() == 2:
			one_step = True
			input = input.unsqueeze(1)
		else:
			one_step = False

		batch, sourceL, dim = memory_bank.size()
		batch_, targetL, dim_ = input.size()
		aeq(batch, batch_)
		aeq(dim, dim_)
		aeq(self.dim, dim)
		if coverage is not None:
			batch_, sourceL_ = coverage.size()
			aeq(batch, batch_)
			aeq(sourceL, sourceL_)

		if coverage is not None:
			cover = coverage.view(-1).unsqueeze(1)
			# add the coverage feature without modifying the caller's memory bank in place
			memory_bank = memory_bank + self.linear_cover(cover).view_as(memory_bank)
			memory_bank = self.tanh(memory_bank)

		# compute attention scores, as in Luong et al.
		align = self.score(input, memory_bank)

		if memory_lengths is not None:
			mask = sequence_mask(memory_lengths, sourceL)
			mask = mask.unsqueeze(1)  # Make it broadcastable.
			align.data.masked_fill_(1 - mask, -float('inf'))

		# Softmax to normalize attention weights
		align_vectors = self.sm(align.view(batch*targetL, sourceL))
		align_vectors = align_vectors.view(batch, targetL, sourceL)

		# each context vector c_t is the weighted average
		# over all the source hidden states
		c = torch.bmm(align_vectors, memory_bank)

		# concatenate
		concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
		attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
		if self.attn_type in ["general", "dot"]:
			attn_h = self.tanh(attn_h)

		if one_step:
			attn_h = attn_h.squeeze(1)
			align_vectors = align_vectors.squeeze(1)
			c = c.squeeze(1)

			# Check output sizes
			batch_, dim_ = attn_h.size()
			aeq(batch, batch_)
			aeq(dim, dim_)
			batch_, sourceL_ = align_vectors.size()
			aeq(batch, batch_)
			aeq(sourceL, sourceL_)
		else:
			attn_h = attn_h.transpose(0, 1).contiguous()
			align_vectors = align_vectors.transpose(0, 1).contiguous()
			c = c.transpose(0, 1).contiguous()

			# Check output sizes
			targetL_, batch_, dim_ = attn_h.size()
			aeq(targetL, targetL_)
			aeq(batch, batch_)
			aeq(dim, dim_)
			targetL_, batch_, sourceL_ = align_vectors.size()
			aeq(targetL, targetL_)
			aeq(batch, batch_)
			aeq(sourceL, sourceL_)

		return attn_h, align_vectors, c
